{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# `pyro.contrib.funsor`, a new backend for Pyro - Building inference algorithms (Part 2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from collections import OrderedDict\n",
    "import functools\n",
    "\n",
    "import torch\n",
    "from torch.distributions import constraints\n",
    "\n",
    "import funsor\n",
    "\n",
    "from pyro import set_rng_seed as pyro_set_rng_seed\n",
    "from pyro.ops.indexing import Vindex\n",
    "from pyro.poutine.messenger import Messenger\n",
    "\n",
    "funsor.set_backend(\"torch\")\n",
    "torch.set_default_dtype(torch.float32)\n",
    "pyro_set_rng_seed(101)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Introduction\n",
    "\n",
    "In part 1 of this tutorial, we were introduced to the new `pyro.contrib.funsor` backend for Pyro.\n",
    "\n",
    "Here we'll look at how to use the components in `pyro.contrib.funsor` to implement a variable elimination inference algorithm from scratch. This tutorial assumes readers are familiar with enumeration-based inference algorithms in Pyro. For background and motivation, readers should consult the [enumeration tutorial](http://pyro.ai/examples/enumeration.html).\n",
    "\n",
    "As before, we'll use `pyroapi` so that we can write our model with standard Pyro syntax."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pyro.contrib.funsor\n",
    "import pyroapi\n",
    "from pyroapi import infer, handlers, ops, optim, pyro\n",
    "from pyroapi import distributions as dist"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We will be working with the following model throughout. It is a discrete-state continuous-observation hidden Markov model with learnable transition and emission distributions that depend on a global random variable."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "data = [torch.tensor(1.)] * 10\n",
    "\n",
    "def model(data, verbose):\n",
    "\n",
    "    p = pyro.param(\"probs\", lambda: torch.rand((3, 3)), constraint=constraints.simplex)\n",
    "    locs_mean = pyro.param(\"locs_mean\", lambda: torch.ones((3,)))\n",
    "    locs = pyro.sample(\"locs\", dist.Normal(locs_mean, 1.).to_event(1))\n",
    "    if verbose:\n",
    "        print(\"locs.shape = {}\".format(locs.shape))\n",
    "\n",
    "    x = 0\n",
    "    for i in pyro.markov(range(len(data))):\n",
    "        x = pyro.sample(\"x{}\".format(i), dist.Categorical(p[x]), infer={\"enumerate\": \"parallel\"})\n",
    "        if verbose:\n",
    "            print(\"x{}.shape = \".format(i), x.shape)\n",
    "        pyro.sample(\"y{}\".format(i), dist.Normal(Vindex(locs)[..., x], 1.), obs=data[i])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can run `model` under the default Pyro backend and the new `contrib.funsor` backend with `pyroapi`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "locs.shape = torch.Size([3])\n",
      "x0.shape =  torch.Size([])\n",
      "x1.shape =  torch.Size([])\n",
      "x2.shape =  torch.Size([])\n",
      "x3.shape =  torch.Size([])\n",
      "x4.shape =  torch.Size([])\n",
      "x5.shape =  torch.Size([])\n",
      "x6.shape =  torch.Size([])\n",
      "x7.shape =  torch.Size([])\n",
      "x8.shape =  torch.Size([])\n",
      "x9.shape =  torch.Size([])\n",
      "locs.shape = torch.Size([3])\n",
      "x0.shape =  torch.Size([])\n",
      "x1.shape =  torch.Size([])\n",
      "x2.shape =  torch.Size([])\n",
      "x3.shape =  torch.Size([])\n",
      "x4.shape =  torch.Size([])\n",
      "x5.shape =  torch.Size([])\n",
      "x6.shape =  torch.Size([])\n",
      "x7.shape =  torch.Size([])\n",
      "x8.shape =  torch.Size([])\n",
      "x9.shape =  torch.Size([])\n"
     ]
    }
   ],
   "source": [
    "# default backend: \"pyro\"\n",
    "with pyroapi.pyro_backend(\"pyro\"):\n",
    "    model(data, verbose=True)\n",
    "    \n",
    "# new backend: \"contrib.funsor\"\n",
    "with pyroapi.pyro_backend(\"contrib.funsor\"):\n",
    "    model(data, verbose=True)"
   ]
  },
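  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The model above fixes the initial state to 0 and draws each `x{i}` from row `p[x]` of the transition matrix. For a fixed value of `locs`, enumeration followed by variable elimination computes the log marginal likelihood of the observations, summing over every hidden state sequence. The sketch below is a plain-Python illustration of that quantity, not code from `pyro.contrib.funsor`: it compares a brute-force sum over all sequences with the forward recursion that eliminates one state at a time. The helper names and the `toy_*` values are hypothetical.\n",
    "\n",
    "```python\n",
    "import itertools\n",
    "import math\n",
    "\n",
    "\n",
    "def normal_logpdf(y, loc, scale=1.0):\n",
    "    # log-density of Normal(loc, scale) evaluated at y\n",
    "    return -0.5 * ((y - loc) / scale) ** 2 - math.log(scale) - 0.5 * math.log(2 * math.pi)\n",
    "\n",
    "\n",
    "def logsumexp(xs):\n",
    "    # numerically stable log(sum(exp(x) for x in xs))\n",
    "    m = max(xs)\n",
    "    return m + math.log(sum(math.exp(x - m) for x in xs))\n",
    "\n",
    "\n",
    "def brute_force_log_marginal(probs, locs, ys):\n",
    "    # sum over all K**T state sequences; the chain starts from state 0, as in the model\n",
    "    num_states = len(probs)\n",
    "    terms = []\n",
    "    for states in itertools.product(range(num_states), repeat=len(ys)):\n",
    "        prev, log_p = 0, 0.0\n",
    "        for x, y in zip(states, ys):\n",
    "            log_p += math.log(probs[prev][x]) + normal_logpdf(y, locs[x])\n",
    "            prev = x\n",
    "        terms.append(log_p)\n",
    "    return logsumexp(terms)\n",
    "\n",
    "\n",
    "def forward_log_marginal(probs, locs, ys):\n",
    "    # variable elimination: sum out one hidden state at a time\n",
    "    num_states = len(probs)\n",
    "    # alpha[k] = log p(y_0, ..., y_t, x_t = k)\n",
    "    alpha = [math.log(probs[0][k]) + normal_logpdf(ys[0], locs[k]) for k in range(num_states)]\n",
    "    for y in ys[1:]:\n",
    "        alpha = [\n",
    "            logsumexp([alpha[j] + math.log(probs[j][k]) for j in range(num_states)])\n",
    "            + normal_logpdf(y, locs[k])\n",
    "            for k in range(num_states)\n",
    "        ]\n",
    "    return logsumexp(alpha)\n",
    "\n",
    "\n",
    "toy_probs = [[0.6, 0.3, 0.1], [0.2, 0.5, 0.3], [0.1, 0.2, 0.7]]  # hypothetical values\n",
    "toy_locs = [0.0, 1.0, 2.0]  # hypothetical values\n",
    "toy_ys = [1.0] * 5\n",
    "brute = brute_force_log_marginal(toy_probs, toy_locs, toy_ys)\n",
    "forward = forward_log_marginal(toy_probs, toy_locs, toy_ys)\n",
    "print(abs(brute - forward) < 1e-9)\n",
    "```\n",
    "\n",
    "The brute-force version touches `K**T` sequences while the forward recursion needs `O(T * K**2)` operations, which is the saving that variable elimination provides on this chain-structured model."
   ]
  },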
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Enumerating discrete variables\n",
    "\n",
    "Our first step is to implement an effect handler that performs parallel enumeration of discrete latent variables. Here we will implement a stripped-down version of `pyro.poutine.enum`, the effect handler behind Pyro's most powerful general-purpose inference algorithms `pyro.infer.TraceEnum_ELBO` and `pyro.infer.mcmc.HMC`.\n",
    "\n",
    "We'll do that by constructing a `funsor.Tensor` representing the support of each discrete latent variable and using the new `pyro.to_data` primitive from part 1 to convert it to a `torch.Tensor` with the appropriate shape."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pyro.contrib.funsor.handlers.named_messenger import NamedMessenger\n",
    "\n",
    "class EnumMessenger(NamedMessenger):\n",
    "    \n",
    "    @pyroapi.pyro_backend(\"contrib.funsor\")  # necessary since we invoke pyro.to_data and pyro.to_funsor\n",
    "    def _pyro_sample(self, msg):\n",
    "        if msg[\"done\"] or msg[\"is_observed\"] or msg[\"infer\"].get(\"enumerate\") != \"parallel\":\n",
    "            return\n",
    "\n",
    "        # We first compute a raw value using the standard enumerate_support method.\n",
    "        # enumerate_support returns a value of shape:\n",
    "        #     (support_size,) + (1,) * len(msg[\"fn\"].batch_shape).\n",
    "        raw_value = msg[\"fn\"].enumerate_support(expand=False)\n",
    "        \n",
    "        # Next we'll use pyro.to_funsor to indicate that this dimension is fresh.\n",
    "        # This is guaranteed because we use msg['name'], the name of this pyro.sample site,\n",
    "        # as the name for this positional dimension, and sample site names must be unique.\n",
    "        funsor_value = pyro.to_funsor(\n",
    "            raw_value,\n",
    "            output=funsor.Bint[raw_value.shape[0]],\n",
    "            dim_to_name={-raw_value.dim(): msg[\"name\"]},\n",
    "        )\n",
    "\n",
    "        # Finally, we convert the value back to a PyTorch tensor with to_data,\n",
    "        # which has the effect of reshaping and possibly permuting dimensions of raw_value.\n",
    "        # Applying to_funsor and to_data in this way guarantees that\n",
    "        # each enumerated random variable gets a unique fresh positional dimension\n",
    "        # and that we can convert the model's log-probability tensors to funsor.Tensors\n",
    "        # in a globally consistent manner.\n",
    "        msg[\"value\"] = pyro.to_data(funsor_value)\n",
    "        msg[\"done\"] = True"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Because this is an introductory tutorial, this implementation of `EnumMessenger` works directly with the site's PyTorch distribution since users familiar with PyTorch and Pyro may find it easier to understand. However, when using `contrib.funsor` to implement an inference algorithm in a more realistic setting, it is usually preferable to do as much computation as possible on funsors, as this tends to simplify complex indexing, broadcasting or shape manipulation logic.\n",
    "\n",
    "For example, in `EnumMessenger`, we might instead call `pyro.to_funsor` on `msg[\"fn\"]`:\n",
    "```py\n",
    "funsor_dist = pyro.to_funsor(msg[\"fn\"], output=funsor.Real)(value=msg[\"name\"])\n",
    "# enumerate_support defined whenever isinstance(funsor_dist, funsor.distribution.Distribution)\n",
    "funsor_value = funsor_dist.enumerate_support(expand=False)\n",
    "raw_value = pyro.to_data(funsor_value)\n",
    "```\n",
    "Most of the more complete inference algorithms implemented in `pyro.contrib.funsor` follow this pattern, and we will see an example later in this tutorial. Before we continue, let's see what effect `EnumMessenger` has on the shapes of random variables in our model:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "locs.shape = torch.Size([3])\n",
      "x0.shape =  torch.Size([3, 1, 1, 1, 1])\n",
      "x1.shape =  torch.Size([3, 1, 1, 1, 1, 1])\n",
      "x2.shape =  torch.Size([3, 1, 1, 1, 1])\n",
      "x3.shape =  torch.Size([3, 1, 1, 1, 1, 1])\n",
      "x4.shape =  torch.Size([3, 1, 1, 1, 1])\n",
      "x5.shape =  torch.Size([3, 1, 1, 1, 1, 1])\n",
      "x6.shape =  torch.Size([3, 1, 1, 1, 1])\n",
      "x7.shape =  torch.Size([3, 1, 1, 1, 1, 1])\n",
      "x8.shape =  torch.Size([3, 1, 1, 1, 1])\n",
      "x9.shape =  torch.Size([3, 1, 1, 1, 1, 1])\n"
     ]
    }
   ],
   "source": [
    "with pyroapi.pyro_backend(\"contrib.funsor\"), \\\n",
    "        EnumMessenger():\n",
    "    model(data, True)"
   ]
  },
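  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The shape convention noted in the `EnumMessenger` comments can be checked with plain `torch.distributions`, independently of funsor. The following sketch, with a made-up transition matrix `p`, enumerates two chained categorical variables by hand and shows that each one ends up on its own tensor dimension, so their log-probabilities broadcast into a joint table. It illustrates the convention only and is not how `EnumMessenger` allocates dimensions.\n",
    "\n",
    "```python\n",
    "import torch\n",
    "from torch.distributions import Categorical\n",
    "\n",
    "# hypothetical transition matrix; each row is a distribution over the next state\n",
    "p = torch.tensor([[0.6, 0.3, 0.1], [0.2, 0.5, 0.3], [0.1, 0.2, 0.7]])\n",
    "\n",
    "d0 = Categorical(probs=p[0])             # batch_shape ()\n",
    "x0 = d0.enumerate_support(expand=False)  # shape (3,)\n",
    "\n",
    "d1 = Categorical(probs=p[x0])            # batch_shape (3,), one row per value of x0\n",
    "x1 = d1.enumerate_support(expand=False)  # shape (3, 1): a new leftmost dimension\n",
    "\n",
    "# log p(x0) + log p(x1 | x0) broadcasts to a table indexed [x1, x0]\n",
    "log_joint = d0.log_prob(x0) + d1.log_prob(x1)\n",
    "total_mass = log_joint.exp().sum()\n",
    "print(x0.shape, x1.shape, log_joint.shape)\n",
    "```\n",
    "\n",
    "`EnumMessenger` arrives at the same kind of layout, but lets `pyro.to_funsor` and `pyro.to_data` choose the positional dimension for each named variable."
   ]
  },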
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Vectorizing a model across multiple samples\n",
    "\n",
    "Next, since our priors over global variables are continuous and cannot be enumerated exactly, we will implement an effect handler that uses a global dimension to draw multiple samples in parallel from the model. Our implementation will allocate a new particle dimension using `pyro.to_data` as in `EnumMessenger` above, but unlike the enumeration dimensions, we want the particle dimension to be shared across all sample sites, so we will mark it as a `DimType.GLOBAL` dimension when invoking `pyro.to_funsor`.\n",
    "\n",
    "Recall that in part 1 we saw that `DimType.GLOBAL` dimensions must be deallocated manually or they will persist until the final effect handler has exited. This low-level detail is taken care of automatically by the `GlobalNameMessenger` handler provided in `pyro.contrib.funsor` as a base class for any effect handlers that allocate global dimensions. Our vectorization effect handler will inherit from this class."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pyro.contrib.funsor.handlers.named_messenger import GlobalNamedMessenger\n",
    "from pyro.contrib.funsor.handlers.runtime import DimRequest, DimType\n",
    "\n",
    "class VectorizeMessenger(GlobalNamedMessenger):\n",
    "    \n",
    "    def __init__(self, size, name=\"_PARTICLES\"):\n",
    "        super().__init__()\n",
    "        self.name = name\n",
    "        self.size = size\n",
    "\n",
    "    @pyroapi.pyro_backend(\"contrib.funsor\")\n",
    "    def _pyro_sample(self, msg):\n",
    "        if msg[\"is_observed\"] or msg[\"done\"] or msg[\"infer\"].get(\"enumerate\") == \"parallel\":\n",
    "            return\n",
    "        \n",
    "        # we'll first draw a raw batch of samples similarly to EnumMessenger.\n",
    "        # However, since we are drawing a single batch from the joint distribution,\n",
    "        # we don't need to take multiple samples if the site is already batched.\n",
    "        if self.name in pyro.to_funsor(msg[\"fn\"], funsor.Real).inputs:\n",
    "            raw_value = msg[\"fn\"].rsample()\n",
    "        else:\n",
    "            raw_value = msg[\"fn\"].rsample(sample_shape=(self.size,))\n",
    "        \n",
    "        # As before, we'll use pyro.to_funsor to register the new dimension.\n",
    "        # This time, we indicate that the particle dimension should be treated as a global dimension.\n",
    "        fresh_dim = len(msg[\"fn\"].event_shape) - raw_value.dim()\n",
    "        funsor_value = pyro.to_funsor(\n",
    "            raw_value,\n",
    "            output=funsor.Reals[tuple(msg[\"fn\"].event_shape)],\n",
    "            dim_to_name={fresh_dim: DimRequest(value=self.name, dim_type=DimType.GLOBAL)},\n",
    "        )\n",
    "        \n",
    "        # finally, convert the sample to a PyTorch tensor using to_data as before\n",
    "        msg[\"value\"] = pyro.to_data(funsor_value)\n",
    "        msg[\"done\"] = True"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's see what effect `VectorizeMessenger` has on the shapes of the values in `model`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "locs.shape = torch.Size([10, 1, 1, 1, 1, 3])\n",
      "x0.shape =  torch.Size([])\n",
      "x1.shape =  torch.Size([])\n",
      "x2.shape =  torch.Size([])\n",
      "x3.shape =  torch.Size([])\n",
      "x4.shape =  torch.Size([])\n",
      "x5.shape =  torch.Size([])\n",
      "x6.shape =  torch.Size([])\n",
      "x7.shape =  torch.Size([])\n",
      "x8.shape =  torch.Size([])\n",
      "x9.shape =  torch.Size([])\n"
     ]
    }
   ],
   "source": [
    "with pyroapi.pyro_backend(\"contrib.funsor\"), \\\n",
    "        VectorizeMessenger(size=10):\n",
    "    model(data, verbose=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "And now in combination with `EnumMessenger`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "locs.shape = torch.Size([10, 1, 1, 1, 1, 3])\n",
      "x0.shape =  torch.Size([3, 1, 1, 1, 1, 1])\n",
      "x1.shape =  torch.Size([3, 1, 1, 1, 1, 1, 1])\n",
      "x2.shape =  torch.Size([3, 1, 1, 1, 1, 1])\n",
      "x3.shape =  torch.Size([3, 1, 1, 1, 1, 1, 1])\n",
      "x4.shape =  torch.Size([3, 1, 1, 1, 1, 1])\n",
      "x5.shape =  torch.Size([3, 1, 1, 1, 1, 1, 1])\n",
      "x6.shape =  torch.Size([3, 1, 1, 1, 1, 1])\n",
      "x7.shape =  torch.Size([3, 1, 1, 1, 1, 1, 1])\n",
      "x8.shape =  torch.Size([3, 1, 1, 1, 1, 1])\n",
      "x9.shape =  torch.Size([3, 1, 1, 1, 1, 1, 1])\n"
     ]
    }
   ],
   "source": [
    "with pyroapi.pyro_backend(\"contrib.funsor\"), \\\n",
    "        VectorizeMessenger(size=10), EnumMessenger():\n",
    "    model(data, verbose=True)"
   ]
  },
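  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The `fresh_dim` computation in `VectorizeMessenger` counts dimensions from the right of the batch shape, skipping event dimensions. A small check with plain `torch.distributions` shows where the particle dimension lands for a site like `locs`. Using `Independent` as a stand-in for `.to_event(1)` is an assumption made for this sketch, as are the variable names.\n",
    "\n",
    "```python\n",
    "import torch\n",
    "from torch.distributions import Independent, Normal\n",
    "\n",
    "size = 10\n",
    "# stand-in for dist.Normal(locs_mean, 1.).to_event(1): batch_shape (), event_shape (3,)\n",
    "fn = Independent(Normal(torch.ones(3), 1.0), 1)\n",
    "\n",
    "raw_value = fn.rsample(sample_shape=(size,))       # shape (10, 3)\n",
    "fresh_dim = len(fn.event_shape) - raw_value.dim()  # 1 - 2 = -1\n",
    "\n",
    "# strip the event dimensions; fresh_dim indexes into what remains\n",
    "batch_shape = raw_value.shape[:raw_value.dim() - len(fn.event_shape)]\n",
    "print(fresh_dim, batch_shape[fresh_dim])\n",
    "```\n",
    "\n",
    "Because the index is relative to the batch shape, the same expression works whether or not the site has event dimensions."
   ]
  },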
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Computing an ELBO with variable elimination\n",
    "\n",
    "Now that we have tools for enumerating discrete variables and drawing batches of samples, we can use those to compute quantities of interest for inference algorithms.\n",
    "\n",
    "Most inference algorithms in Pyro work with `pyro.poutine.Trace`s, custom data structures that contain parameters and sample site distributions and values and all of the associated metadata needed for inference computations. Our third effect handler `LogJointMessenger` departs from this design pattern, eliminating a tremendous amount of boilerplate in the process. It will automatically build up a lazy Funsor expression for the logarithm of the joint probability density of a model; when working with `Trace`s, this process must be triggered manually by calling `Trace.compute_log_probs()` and eagerly computing an objective from the resulting individual log-probability tensors in the trace.\n",
    "\n",
    "In our implementation of `LogJointMessenger`, unlike the previous two effect handlers, we will call `pyro.to_funsor` on both the sample value and the distribution to show how nearly all inference operations including log-probability density evaluation can be performed on `funsor.Funsor`s directly."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "class LogJointMessenger(Messenger):\n",
    "\n",
    "    def __enter__(self):\n",
    "        self.log_joint = funsor.Number(0.)\n",
    "        return super().__enter__()\n",
    "\n",
    "    @pyroapi.pyro_backend(\"contrib.funsor\")\n",
    "    def _pyro_post_sample(self, msg):\n",
    "        \n",
    "        # for Monte Carlo-sampled variables, we don't include a log-density term:\n",
    "        if not msg[\"is_observed\"] and not msg[\"infer\"].get(\"enumerate\"):\n",
    "            return\n",
    "        \n",
    "        with funsor.interpreter.interpretation(funsor.terms.lazy):\n",
    "            funsor_dist = pyro.to_funsor(msg[\"fn\"], output=funsor.Real)\n",
    "            funsor_value = pyro.to_funsor(msg[\"value\"], output=funsor_dist.inputs[\"value\"])\n",
    "            self.log_joint += funsor_dist(value=funsor_value)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "And finally the actual loss function, which applies our three effect handlers to compute an expression for the log-density, marginalizes over discrete variables with `funsor.ops.logaddexp`, averages over Monte Carlo samples with `funsor.ops.add`, and evaluates the final lazy expression using Funsor's `optimize` interpretation for variable elimination.\n",
    "\n",
    "Note that `log_z` exactly collapses the model's local discrete latent variables but is an ELBO wrt any continuous latent variables, and is thus equivalent to a simple version of `TraceEnum_ELBO` with an empty guide."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "@pyroapi.pyro_backend(\"contrib.funsor\")\n",
    "def log_z(model, model_args, size=10):\n",
    "    with LogJointMessenger() as tr, \\\n",
    "            VectorizeMessenger(size=size) as v, \\\n",
    "            EnumMessenger():\n",
    "        model(*model_args)\n",
    "\n",
    "    with funsor.interpreter.interpretation(funsor.terms.lazy):\n",
    "        prod_vars = frozenset({v.name})\n",
    "        sum_vars = frozenset(tr.log_joint.inputs) - prod_vars\n",
    "        \n",
    "        # sum over the discrete random variables we enumerated\n",
    "        expr = tr.log_joint.reduce(funsor.ops.logaddexp, sum_vars)\n",
    "        \n",
    "        # average over the sample dimension\n",
    "        expr = expr.reduce(funsor.ops.add, prod_vars) - funsor.Number(float(size))\n",
    "\n",
    "    return pyro.to_data(funsor.optimizer.apply_optimizer(expr))"
   ]
  },
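  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The two reductions in `log_z` can be mimicked on an ordinary tensor. The sketch below assumes a hypothetical dense log-joint table indexed by `[particle, x0, x1]` and uses `torch.logsumexp` in place of `reduce(funsor.ops.logaddexp, ...)`. It illustrates the arithmetic only, not Funsor's lazy evaluation.\n",
    "\n",
    "```python\n",
    "import torch\n",
    "\n",
    "torch.manual_seed(0)\n",
    "size, num_states = 4, 3\n",
    "# hypothetical dense log-joint, indexed [particle, x0, x1]\n",
    "log_joint = torch.randn(size, num_states, num_states)\n",
    "\n",
    "# marginalize both discrete variables in a single reduction ...\n",
    "joint_reduce = torch.logsumexp(log_joint, dim=(-2, -1))\n",
    "# ... or eliminate x1 first and then x0; the two results agree\n",
    "nested_reduce = torch.logsumexp(torch.logsumexp(log_joint, dim=-1), dim=-1)\n",
    "\n",
    "# combine particles: summing and dividing by size gives a Monte Carlo average\n",
    "mc_average = joint_reduce.sum() / size\n",
    "print(torch.allclose(joint_reduce, nested_reduce, atol=1e-5))\n",
    "```\n",
    "\n",
    "Eliminating variables one at a time gives the same answer as a single joint reduction. For a chain-structured model, a good elimination order avoids materializing the full joint table in the first place."
   ]
  },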
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Putting it all together\n",
    "\n",
    "Finally, with all this machinery implemented, we can compute stochastic gradients wrt the ELBO."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor(-133.6274, grad_fn=<AddBackward0>)\n",
      "tensor(-129.2379, grad_fn=<AddBackward0>)\n",
      "tensor(-125.9609, grad_fn=<AddBackward0>)\n",
      "tensor(-123.7484, grad_fn=<AddBackward0>)\n",
      "tensor(-122.3034, grad_fn=<AddBackward0>)\n"
     ]
    }
   ],
   "source": [
    "with pyroapi.pyro_backend(\"contrib.funsor\"):\n",
    "    model(data, verbose=False)  # initialize parameters\n",
    "    params = [pyro.param(\"probs\").unconstrained(), pyro.param(\"locs_mean\").unconstrained()]\n",
    "\n",
    "optimizer = torch.optim.Adam(params, lr=0.1)\n",
    "for step in range(5):\n",
    "    optimizer.zero_grad()\n",
    "    log_marginal = log_z(model, (data, False))\n",
    "    (-log_marginal).backward()\n",
    "    optimizer.step()\n",
    "    print(log_marginal)"
   ]
  },
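  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The loop above hands Adam the unconstrained tensors behind `probs` and `locs_mean`. As a rough illustration of why optimizing in unconstrained space keeps a simplex-constrained parameter valid, the sketch below uses only `torch.distributions.transform_to`. The toy objective and tensor names are made up for illustration, and the assumption that Pyro's parameter store applies the same kind of transform is not verified here.\n",
    "\n",
    "```python\n",
    "import torch\n",
    "from torch.distributions import constraints, transform_to\n",
    "\n",
    "to_simplex = transform_to(constraints.simplex)\n",
    "\n",
    "# hypothetical unconstrained parameter: each row of 2 reals maps to a point on the 3-simplex\n",
    "unconstrained = torch.zeros(3, 2, requires_grad=True)\n",
    "optimizer = torch.optim.Adam([unconstrained], lr=0.1)\n",
    "\n",
    "for step in range(5):\n",
    "    optimizer.zero_grad()\n",
    "    probs = to_simplex(unconstrained)\n",
    "    # toy objective: push the probability mass of every row toward the last state\n",
    "    loss = -probs[:, -1].log().sum()\n",
    "    loss.backward()\n",
    "    optimizer.step()\n",
    "\n",
    "probs = to_simplex(unconstrained).detach()\n",
    "row_sums = probs.sum(-1)\n",
    "print(probs.shape, row_sums)\n",
    "```\n",
    "\n",
    "However far the optimizer moves `unconstrained`, the transformed rows remain valid probability vectors, so no projection step is needed after `optimizer.step()`."
   ]
  },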
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
